# iqa_tool_worker.py
import base64
import contextlib
import json
import sys
from io import BytesIO

import pyiqa
import torch
import torchvision.transforms as transforms
from PIL import Image

# Per-model parameters for the logistic score mapping, read once at import time
# from the fitting-results file. Each model's entry must provide a "beta" list
# (see logistic()). JSON is UTF-8 by specification, so decode it explicitly
# rather than relying on the platform/locale default encoding.
with open('/root/IQA/IQA-Agent/iqa_models_results/model_fitting_result.json', encoding='utf-8') as f:
    model_data_params = json.load(f)

def logistic(model_name, X, params=None):
    """Map raw metric scores through the fitted 5-parameter logistic for a model.

    Computes::

        beta1 * (0.5 - 1 / (1 + exp(beta2 * (X - beta3)))) + beta4 * X + beta5

    where ``beta`` is the list stored under ``params[model_name]['beta']``.

    Args:
        model_name: Key of the model's entry in the parameter mapping.
        X: Tensor of raw scores. Non-floating tensors (int/bool) are promoted
            to float32 first; otherwise the betas would be created with an
            integer dtype and silently truncated.
        params: Optional mapping ``{model_name: {"beta": [b1, b2, b3, b4, b5]}}``.
            Defaults to the module-level ``model_data_params``.

    Returns:
        Tensor of mapped scores, broadcast to ``X``'s shape, on ``X``'s device.

    Raises:
        KeyError: if ``model_name`` or its ``'beta'`` entry is missing.
        ValueError: if the stored ``beta`` does not contain exactly 5 values.
    """
    if params is None:
        params = model_data_params
    if not X.is_floating_point():
        X = X.float()
    beta = torch.tensor(params[model_name]['beta'], dtype=X.dtype, device=X.device)
    beta1, beta2, beta3, beta4, beta5 = beta
    # 1 / (1 + exp(z)) == sigmoid(-z); torch.sigmoid evaluates it without the
    # exp(z) -> inf intermediate for large z.
    part = 0.5 - torch.sigmoid(-beta2 * (X - beta3))
    return beta1 * part + beta4 * X + beta5

def decode_image(image_input: str, device="cuda"):
    """Decode a base64-encoded image into a batched tensor on ``device``.

    ``image_input`` may be a bare base64 string or a ``data:image...`` URI;
    for a data URI only the text after the last comma is decoded. The image
    is converted to RGB and turned into a tensor with torchvision's
    ``ToTensor``, then given a leading batch dimension of size 1.
    """
    encoded = image_input
    if encoded.startswith("data:image"):
        # Strip the data-URI header, keeping what follows the last comma.
        encoded = encoded.rpartition(",")[2]
    raw_bytes = base64.b64decode(encoded)
    with Image.open(BytesIO(raw_bytes)) as source:
        rgb_image = source.convert("RGB")
    to_tensor = transforms.ToTensor()
    batched = to_tensor(rgb_image)[None, ...]
    return batched.to(device)

def main():
    """Score one request read from stdin and print the result as JSON.

    The stdin payload is a JSON object with ``"model_name"`` and either:

    * ``"reference_image"`` and ``"distorted_image"`` (full-reference), or
    * ``"image"`` (no-reference).

    Image values are base64 strings accepted by ``decode_image``. The raw
    metric score is mapped through ``logistic`` for ``model_name``. The
    result is written to stdout as a single line ``{"score": <float>}``.

    Raises:
        KeyError: if a required payload key is missing, or if ``model_name``
            has no entry in the fitting results.
    """
    payload = json.load(sys.stdin)
    model_name = payload["model_name"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # stdout carries the JSON result. Anything printed while creating or
    # running the metric (e.g. weight-download or log messages, if any) goes
    # to stderr instead, so the caller can parse stdout cleanly.
    with contextlib.redirect_stdout(sys.stderr):
        model = pyiqa.create_metric(model_name, device=device).to(device)

        with torch.no_grad():
            if "reference_image" in payload and "distorted_image" in payload:
                ref_tensor = decode_image(payload["reference_image"], device)
                dist_tensor = decode_image(payload["distorted_image"], device)
                # pyiqa full-reference metrics are called as
                # metric(target, ref): distorted image first, reference second.
                score = model(dist_tensor, ref_tensor)
            else:
                img_tensor = decode_image(payload["image"], device)
                score = model(img_tensor)
            score = logistic(model_name, score)

    print(json.dumps({"score": score.item()}))

if __name__ == "__main__":
    # Script entry point: handle one JSON request from stdin, then exit.
    sys.exit(main())
